#!/usr/bin/env python3

import os
import pickle as pkl
import pprint as pp
import matplotlib.pyplot as plt
import numpy as np
import time
import subprocess
import argparse
import xlrd
import random
from enum import Enum


class RunStatus(Enum):
    """Outcome of a single solver run, as classified by get_result()."""
    OK = 0  # run finished and reported a non-negative query count
    TLE = 1  # run unfinished (or produced no output file) and the time limit had elapsed
    WA = 2  # run finished but reported a negative query count
    RE = 3  # run unfinished (or produced no output file) before the time limit elapsed


# Number of runs completed during this invocation; incremented by finish()
# and printed as progress against total_benchmark.
total_finished = 0


class ResultInfo:
    """Result of one solver run (or an average over several runs).

    Attributes:
        status: the RunStatus assigned to the run.
        num: number of queries reported by the run; get_average() also stores
            a float average here.
        size_list: list of floats parsed from the solver output, or None.
        survive_list: list of 0/1 flags, one per query, parsed from the solver
            output, or None when the instance was built without them
            (presumably 1 means the candidate survived the query -- confirm
            against the solver).
    """

    def __init__(self, status: "RunStatus", num: int, size_list = None, is_survive = None):
        self.status = status
        self.num = num
        self.size_list = size_list
        self.survive_list = is_survive

    def __str__(self):
        return "#query: " + str(self.num)

    def __lt__(self, other):
        # Ordering by query count lets sorted() be applied to lists of results.
        return self.num < other.num

    def _flags(self):
        """Return survive_list, treating a missing list as empty."""
        return [] if self.survive_list is None else self.survive_list

    def checkError(self, threshold):
        """Check the flags that follow the first streak of `threshold` ones.

        Returns False when a 0 flag appears after the first run of
        `threshold` consecutive 1 flags; returns True otherwise, including
        when no such run exists or there are no flags at all.
        """
        flags = self._flags()
        streak = 0
        for idx, flag in enumerate(flags):
            if flag != 1:
                streak = 0
                continue
            streak += 1
            if streak == threshold:
                # Only the first qualifying streak is examined.
                return all(later != 0 for later in flags[idx:])
        return True

    def getNum(self, threshold, name):
        """Return the query count charged to this run.

        For names that do not contain "Eps" this is the number of flags (at
        least 1). For "Eps" names it is the 1-based position at which the
        first run of `threshold` consecutive 1 flags completes, falling back
        to the number of flags (at least 1) when no such run exists.
        """
        flags = self._flags()
        if "Eps" not in name:
            return max(len(flags), 1)
        streak = 0
        for asked, flag in enumerate(flags, start=1):
            if flag == 1:
                streak += 1
                if streak == threshold:
                    return asked
            else:
                streak = 0
        return max(len(flags), 1)


# Solvers to run for each value of the --experiment argument.
# Entry 0 lists every solver that appears in entries 1-4.
solver_list_for_exp = {
    0: ["RandomSy", "SampleSy", "SampleIn0.1", "SampleDe", "SampleSy2", "SampleSy20", "SampleSy5000", "EpsSy"],
    1: ["RandomSy", "SampleSy", "EpsSy"],
    2: ["SampleSy", "SampleIn0.1", "SampleDe"],
    3: ["SampleSy2", "SampleSy20", "SampleSy5000"],
    4: ["EpsSy"]
}

# Dataset name -> directory scanned for ".sl" benchmark files.
# The dataset names are also used as sheet names in benchmark.xls.
benchmark_folder = {
    "string": "../benchmarks/string/final",
    "repair": "../benchmarks/repair"

}

# depth_info[dataset][benchmark file name] -> integer depth read from
# benchmark.xls at import time. collect_all_benchmark() passes the depth to the
# solver as its last argument and skips benchmarks whose depth is -1.
depth_info = {}
data = xlrd.open_workbook("benchmark.xls")
# One sheet per dataset; the folder value (`fold`) is not used here.
for d_type, fold in benchmark_folder.items():
    table = data.sheet_by_name(d_type)
    depth_info[d_type] = {}
    nrows = table.nrows
    for i in range(nrows):
        # Each row must hold exactly two cells: benchmark name and depth.
        benchmark_name, depth = table.row_values(i)
        depth = int(depth)
        depth_info[d_type][benchmark_name] = depth


def collect_all_benchmark():
    """Build the list of [run name, command parts] for every job to execute.

    Every file whose name contains ".sl" in each benchmark folder is paired
    with every solver in solver_list, except benchmarks whose recorded depth
    is -1. The command parts are joined with spaces by run_benchmark().
    """
    jobs = []
    for b_type, folder in benchmark_folder.items():
        for benchmark in os.listdir(folder):
            if ".sl" not in benchmark:
                continue
            depth = depth_info[b_type][benchmark]
            if depth == -1:
                # Depth -1 marks a benchmark that is excluded from the runs.
                continue
            for solver in solver_list:
                parts = [b_type, benchmark, solver]
                name = "_".join(parts)
                # Shell prefix: cap virtual memory, then bound wall-clock time.
                launcher = "ulimit -v {0};timeout {1} ./../build/main".format(
                    str(memory_limit), str(time_limit))
                outputs = [name + ".log", name + ".out", str(depth)]
                jobs.append([name, [launcher] + parts + outputs])
    return jobs


# Run name -> list of ResultInfo, one per completed repetition.
# Previously saved results are reloaded at import time so that an interrupted
# campaign can be continued.
result = {}
if os.path.exists("result.pkl"):
    with open("result.pkl", "rb") as inp:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only use a result.pkl produced by this script.
        result = pkl.load(inp)


def update():
    """Persist the in-memory result table to result.pkl.

    The data is written to a temporary file first and then moved into place,
    so an interruption during the write cannot leave a truncated result.pkl
    that would fail to load on the next start.
    """
    tmp_path = "result.pkl.tmp"
    with open(tmp_path, "wb") as oup:
        pkl.dump(result, oup)
    os.replace(tmp_path, "result.pkl")


def get_result(file_name, is_timeout):
    """Parse a solver output file into a ResultInfo.

    Expected file layout (at least four lines):
        1. status string; only "Unfinished" is treated specially
        2. signed integer query count (a negative value marks a wrong answer)
        3. abs(count) + 1 whitespace-separated floats
        4. abs(count) whitespace-separated integers (ignored when count is 0)

    Args:
        file_name: path of the output file written by the solver.
        is_timeout: whether the run exceeded the time limit; decides between
            TLE and RE for unfinished or missing outputs.

    Returns:
        A ResultInfo. A missing file yields TLE/RE with zero queries.

    Raises:
        AssertionError: if the file has fewer than four lines or the list
            lengths do not match the query count.
    """
    try:
        with open(file_name, "r") as inp:
            lines = inp.readlines()
    except FileNotFoundError:
        return ResultInfo(RunStatus.TLE if is_timeout else RunStatus.RE, 0, [1], [])

    assert len(lines) >= 4, "result file has fewer than 4 lines: " + file_name
    # strip()/split() instead of slicing off the last character: the final
    # line may lack a trailing newline, and fields may carry extra spaces.
    status = lines[0].strip()
    number = int(lines[1])
    size_list = [float(token) for token in lines[2].split()]
    assert len(size_list) == abs(number) + 1
    is_survive = [int(token) for token in lines[3].split()] if number != 0 else []
    assert len(is_survive) == abs(number)

    if status == "Unfinished":
        unfinished_status = RunStatus.TLE if is_timeout else RunStatus.RE
        return ResultInfo(unfinished_status, number, size_list, is_survive)
    if number < 0:
        return ResultInfo(RunStatus.WA, -number, size_list, is_survive)
    return ResultInfo(RunStatus.OK, number, size_list, is_survive)


def finish(pos, status):
    """Record `status` for the run occupying slot `pos` and save the cache.

    The slot itself is not released here; the caller clears it afterwards.
    """
    global total_finished
    slot = thread_pool[pos]
    assert slot is not None
    name = slot["name"]
    result.setdefault(name, []).append(status)
    total_finished += 1
    print("Finish ", name, "Now " + str(total_finished) + "/" + str(total_benchmark))
    update()


def deal_with(pos):
    """Poll the process in slot `pos` and release the slot once it is done.

    A process counts as done when it has exited with a code other than 1, or
    when it has been running longer than time_limit (it is then terminated
    and recorded as timed out). An exit code of 1 is not treated as
    completion: such a slot stays occupied until the time limit passes.
    """
    slot = thread_pool[pos]
    exit_code = slot["thread"].poll()
    elapsed = time.time() - slot["start_time"]
    overdue = elapsed > time_limit

    if exit_code is not None and exit_code != 1:
        finish(pos, get_result(slot["oup"], overdue))
        thread_pool[pos] = None
    elif overdue:
        slot["thread"].terminate()
        finish(pos, get_result(slot["oup"], True))
        thread_pool[pos] = None


def run_benchmark(name, command, ti):
    """Wait for a free slot in thread_pool and launch one solver run in it.

    Args:
        name: run name, used as the key in the result table.
        command: command parts produced by collect_all_benchmark(); element 4
            is the log file name and element 5 the output file name.
        ti: time limit in seconds for this run; when it differs from the
            global time_limit the "timeout" value in the command is replaced.
    """
    target_pos = None
    while target_pos is None:
        for pos in range(thread_num):
            if thread_pool[pos] is not None:
                deal_with(pos)
            if thread_pool[pos] is None:
                target_pos = pos
                break
        clear()
        time.sleep(1)
    print("Start ", name)

    new_command = list(command)
    # A random suffix keeps the output files of repeated runs apart.
    new_command[5] += str(random.randint(0, 10 ** 9))
    cm = " ".join(new_command)
    if ti != time_limit:
        cm = cm.replace("timeout " + str(time_limit), "timeout " + str(ti))

    # The child's stdout is never read by this script, so it is discarded:
    # an unread PIPE can block the child once the pipe buffer fills and keeps
    # a file descriptor open for every run.
    thread_pool[target_pos] = {
        "name": name,
        "oup": "temp/" + new_command[5],
        "log": "log/" + new_command[4],
        "start_time": time.time(),
        "thread": subprocess.Popen(cm, stdout=subprocess.DEVNULL, shell=True),
    }


def clear():
    """Kill listed processes whose HH:MM:SS time field exceeds time_limit.

    The process list comes from `ps -af | grep FlashFill`. The time is
    compared at minute granularity, with hours now included, so a process at
    01:05:00 is treated as 65 minutes rather than 5.
    """
    # NOTE(review): the pattern "FlashFill" does not match the
    # "./../build/main" binary launched by this script -- confirm which
    # processes this is meant to catch.
    listing = os.popen("ps -af | grep FlashFill").read()
    for row in listing.split("\n"):
        fields = [token for token in row.split(' ') if len(token) > 0]
        if len(fields) < 6:
            continue
        pid = int(fields[1])
        minutes = 0
        for token in fields:
            # Match an 8-character HH:MM:SS field; the last match wins.
            if len(token) == 8 and token[2] == ":" and token[5] == ":":
                hours_text, minutes_text = token[0:2], token[3:5]
                # Skip lookalike tokens (e.g. a command argument) that are not numeric.
                if hours_text.isdigit() and minutes_text.isdigit():
                    minutes = int(hours_text) * 60 + int(minutes_text)
        if minutes * 60 > time_limit:
            print("Kill " + str(pid))
            os.system("kill " + str(pid))


def collect_all_result():
    """Block until every slot in thread_pool has been released.

    Each pass polls all occupied slots; if any slot was occupied at the start
    of the pass, stale processes are cleaned up and the loop sleeps for one
    second before polling again.
    """
    while True:
        polled = 0
        for slot in range(thread_num):
            if thread_pool[slot] is None:
                continue
            deal_with(slot)
            polled += 1
        if polled == 0:
            return
        clear()
        time.sleep(1)


def get_info(full_name):
    """Split a run name at its last underscore.

    Returns a (prefix, suffix) tuple; for run names built as
    "<dataset>_<benchmark>_<solver>" the suffix is the solver name. A name
    without an underscore yields an empty prefix.
    """
    prefix, _, suffix = full_name.rpartition("_")
    return prefix, suffix


def get_average(value, name):
    """Average the query counts of the runs recorded for one run name.

    Only runs with status OK are averaged when at least one exists; otherwise
    all runs are averaged and the summary is tagged RE. Query counts are
    obtained with getNum(5, name).

    Returns:
        (ResultInfo holding the average in .num, list of the averaged counts),
        or (None, []) when `value` is empty.
    """
    ok_counts = [run.getNum(5, name) for run in value if run.status == RunStatus.OK]
    if ok_counts:
        mean = 1.0 * sum(ok_counts) / len(ok_counts)
        return ResultInfo(RunStatus.OK, mean), ok_counts

    # No successful run: fall back to every recorded run.
    all_counts = [run.getNum(5, name) for run in value]
    if not all_counts:
        return None, []
    mean = 1.0 * sum(all_counts) / len(all_counts)
    return ResultInfo(RunStatus.RE, mean), all_counts


def draw_exp123(solver_list, name_map, restore_file, compare_ratio, data_name_map = lambda x: x):
    """Print pairwise solver comparisons and plot sorted query counts.

    Only benchmarks that have an averaged result for every solver in
    `solver_list` are used.

    Args:
        solver_list: solver names to compare; at most three, because three
            line and marker colours are defined.
        name_map: callable mapping a solver name to its legend label.
        restore_file: file name of the figure, saved under "figure/".
        compare_ratio: fractions alpha; for each one the int(alpha * n)
            benchmarks with the largest per-benchmark maximum query count are
            compared using the geometric mean of pairwise ratios.
        data_name_map: callable mapping a solver name to the name used in the
            printed comparison.
    """
    # benchmark name -> {solver: (averaged ResultInfo, list of per-run counts)}
    benchmark_result = {}
    for key, value in result.items():
        benchmark_name, solver = get_info(key)
        if solver not in solver_list:
            continue
        average, value_list = get_average(value, key)
        if average is None:
            continue
        benchmark_result.setdefault(benchmark_name, {})[solver] = (average, value_list)

    dataset_list = ["repair", "string"]
    pos = [211, 212]
    title_list = ["Repair Track", "String Track"]

    for where, dataset_name in enumerate(dataset_list):
        # Rank complete benchmarks by the largest query count over the solvers.
        max_key_list = []
        for key, value in benchmark_result.items():
            if len(value) != len(solver_list):
                continue
            if dataset_name not in key:
                continue
            max_for_solvers = 0
            for solver in solver_list:
                max_for_solvers = max(max_for_solvers, value[solver][0].num)
            max_key_list.append([max_for_solvers, key])
        max_key_list = sorted(max_key_list)

        for alpha in compare_ratio:
            num = int(alpha * len(max_key_list))
            # A slice of [-0:] would select the whole list, so an empty
            # selection has to be handled explicitly.
            hardest = max_key_list[-num:] if num > 0 else []
            count = len(hardest)
            average_ratio = {}
            for _, key in hardest:
                value = benchmark_result[key]
                for first in solver_list:
                    for second in solver_list:
                        skey = first + ":" + second
                        if skey not in average_ratio:
                            average_ratio[skey] = 1
                        average_ratio[skey] *= value[first][0].num / value[second][0].num
            print("For the hardest {0}% benchmarks in".format(alpha * 100), title_list[where])
            for key in average_ratio:
                # Geometric mean of the per-benchmark ratios; average_ratio is
                # non-empty only when count >= 1.
                average_ratio[key] = average_ratio[key] ** (1.0 / count)
                if average_ratio[key] > 1 + 10 ** (-9):
                    l, r = key.split(":")
                    print(data_name_map(l), "takes", str((average_ratio[key] - 1) * 100) + "%", "more queries than", data_name_map(r), "in average")

    for where, dataset_name in enumerate(dataset_list):
        plt.subplot(pos[where])
        solver_result = {name: [] for name in solver_list}
        for key, value in benchmark_result.items():
            if dataset_name not in key:
                continue
            if len(value) != len(solver_list):
                continue
            for solver_name, (cost, _) in value.items():
                solver_result[solver_name].append(cost)

        # Each solver's curve is its own results sorted by query count.
        for name in solver_list:
            solver_result[name] = sorted(solver_result[name])

        ls_list = ['--', '-', '-.', ':', '-.']
        lc = ['orange', 'indianred', 'cornflowerblue']
        pc = ['peru', 'darkred', 'mediumblue']
        name_list = []
        for i, name in enumerate(solver_list):
            name_list.append(name_map(name))
            y = [info.num for info in solver_result[name]]
            x = list(range(1, len(y) + 1))
            ls = ls_list[i]

            # Markers are drawn on roughly every 15th point plus the last one.
            X = []
            Y = []
            step = max(len(x) // 15, 1)
            for j in range(len(x)):
                if (j % step == 0 and j > 0) or j + 1 == len(x):
                    X.append(x[j])
                    Y.append(y[j])

            plt.scatter(np.array(X), np.array(Y), marker='^', s=10, color=pc[i])
            plt.plot(np.array(x), np.array(y), ls=ls, color=lc[i], alpha=0.7)
        plt.title(title_list[where])
        plt.legend(name_list)
        plt.ylabel('#Queries')
        if where == 1:
            plt.xlabel('#Benchmarks')
    plt.tight_layout()
    plt.savefig("figure/" + restore_file)
    plt.cla()
    # Start a fresh figure so the next drawing call does not reuse this one.
    plt.figure()

def draw_4(restore_file):
    """Plot, per dataset, the EpsSy error rate and average query count for thresholds 0..5.

    The left subplot is the repair dataset and the right one the string
    dataset; each has a twin y-axis so both curves share the threshold axis.
    The figure is saved as "figure/" + restore_file.
    """
    where = -1
    ax = []
    ax.append(plt.subplot(121))
    ax.append(ax[0].twinx())
    ax.append(plt.subplot(122))
    ax.append(ax[2].twinx())
    ax[0].set_title('Repair Track')
    ax[2].set_title('String Track')
    ax[0].set_xlabel('Threshold')
    ax[0].set_ylabel('Error rate')
    ax[2].set_xlabel('Threshold')
    ax[3].set_ylabel('#Queries')
    # Regroup as ax[dataset index] = [error-rate axis, query-count axis].
    ax = [[ax[0], ax[1]], [ax[2], ax[3]]]
    # NOTE(review): lc, pc and pos are not used in this function.
    lc = ['orange', 'indianred', 'cornflowerblue']
    pc = ['peru', 'darkred', 'mediumblue']
    dataset_list = ["repair", "string"]
    pos = [211, 212]
    N = 6
    for d_name in dataset_list:
        where += 1
        # error_num[0] counts the runs (used as the denominator below);
        # error_num[i] for i >= 1 counts runs for which checkError(i) fails.
        error_num = [0.0 for i in range(N)]
        average = [0.0 for i in range(N)]
        for key, value in result.items():
            benchmark_name, solver = get_info(key)
            if d_name not in key or "EpsSy" not in key: continue
            for w in value:
                error_num[0] += 1
                for i in range(1, N):
                    if not w.checkError(i):
                        error_num[i] += 1
                    average[i] += w.getNum(i, key)
        # NOTE(review): count and total_sample are accumulated but never read.
        count = 0.0
        total_sample = 0.0
        for key, value in result.items():
            benchmark_name, sovler = get_info(key)
            if d_name not in key or "SampleSy" not in key: continue
            for w in value:
                count += 1
                total_sample += w.num
        # NOTE(review): total is 0 when no EpsSy run exists for this dataset,
        # which makes the divisions below raise ZeroDivisionError.
        total = error_num[0]
        # Index 0 is normalised too, so the threshold-0 point is always an
        # error rate of 1.0 and an average of 0.
        for i in range(N):
            error_num[i] = 1.0 * error_num[i] / total
            average[i] = 1.0 * average[i] / total
        X = list(range(N))
        ax[where][0].plot(np.array(X), np.array(error_num), ls="-", color="orange", alpha=0.7, marker = "^", markerfacecolor="peru", label="error rate")
        # The last point is labelled separately with a different alignment.
        for i in range(N - 1):
            ax[where][0].text(i, error_num[i], '%.1f' % (error_num[i]* 100) + '%', ha = 'left', va = 'bottom', fontsize=9)
        ax[where][0].text(N-1, error_num[N-1], '%.1f' % (error_num[N-1] * 100) + '%', ha = 'right', va = 'top', fontsize=9)
        ax[where][1].plot(np.array(X), np.array(average), ls = "-", color="cornflowerblue", alpha = 0.7, marker = "o", markerfacecolor = "mediumblue", label="query length")
        for i in range(N - 1):
            ax[where][1].text(i + 0.1, average[i], '%.1f' % average[i], ha = 'left', va = 'top', fontsize = 9)
        ax[where][1].text(N-1, average[N-1], '%.1f' % average[N-1], ha = 'right', va = 'bottom', fontsize=9)
        # Merge the handles of both y-axes into a single legend.
        lines, labels = ax[where][0].get_legend_handles_labels()
        lines2, labels2 = ax[where][1].get_legend_handles_labels()
        ax[where][1].legend(lines + lines2, labels + labels2, loc=7)
    plt.tight_layout()
    plt.savefig("figure/" + restore_file)
    plt.cla()

def draw(exp_id):
    """Produce the figure and printed summary for experiment `exp_id` (1-4).

    The "figure" output directory is created when missing. Any other
    `exp_id` only prints the header line.
    """
    # makedirs with exist_ok avoids spawning a shell and the gap between the
    # existence check and the creation.
    os.makedirs("figure", exist_ok=True)
    print("For experiment {0}:".format(exp_id))
    if exp_id == 1:
        draw_exp123(["SampleSy", "RandomSy", "EpsSy"], lambda x: x, "exp1.png", [0.3, 1])
    elif exp_id == 2:
        name_map = {"SampleDe": "weakened ps", "SampleSy": "ps", "SampleIn0.1": "enhanced ps"}
        draw_exp123(["SampleDe", "SampleSy", "SampleIn0.1"], lambda x: name_map[x], "exp2.png", [])
    elif exp_id == 3:
        # x[8:] drops the "SampleSy" prefix and leaves the numeric suffix.
        draw_exp123(["SampleSy2", "SampleSy20", "SampleSy5000"], lambda x: "w = " + x[8:], "exp3.png", [0.3, 1], lambda x: "S(" + x[8:] + ")")
    elif exp_id == 4:
        draw_4("exp4.png")
    print()


def parse_args(argv=None):
    """Parse the command-line options of the benchmark runner.

    Args:
        argv: optional list of argument strings; when None, argparse reads
            sys.argv[1:] as before.

    Returns:
        The argparse namespace with threadnum, memory, experiment, repeatnum
        and cache.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-tn', '--threadnum', type=int, default=8,
                        help="number of solver processes run concurrently")
    parser.add_argument('-m', '--memory', type=int, default=6,
                        help="value multiplied by 1024 * 1024 and passed to 'ulimit -v'")
    parser.add_argument('-exp', '--experiment', type=int, choices=[0, 1, 2, 3, 4], default=0,
                        help="experiment to run and draw; 0 runs all solvers and draws 1-4")
    parser.add_argument('-r', '--repeatnum', type=int, default=5,
                        help="number of results wanted for every benchmark/solver pair")
    parser.add_argument('-c', '--cache', type=str, default="Continue", choices=["Restart", "Continue", "R", "C"],
                        help="Restart/R removes result.pkl first; Continue/C keeps it")
    return parser.parse_args(argv)
    

if __name__ == "__main__":
    # Directories referenced by the slot bookkeeping in run_benchmark().
    os.makedirs("log", exist_ok=True)
    os.makedirs("temp", exist_ok=True)

    args = parse_args()
    if args.cache[0] == 'R':
        # Restart: drop the cache file AND the results already loaded from it
        # at import time; otherwise the old runs would still count as done
        # and be written back by the next update().
        if os.path.exists("result.pkl"):
            os.remove("result.pkl")
        result.clear()

    # Module-level settings read by the helper functions above. (A `global`
    # statement is a no-op at module level, so none is needed here.)
    thread_num = args.threadnum
    thread_pool = [None for _ in range(thread_num)]
    time_limit = 30 * 60  # seconds per run
    memory_limit = args.memory * 1024 * 1024  # passed to `ulimit -v`
    solver_list = solver_list_for_exp[args.experiment]
    total_benchmark = 0
    target_num = args.repeatnum

    benchmark_list = collect_all_benchmark()

    # First pass: count the missing repetitions so progress can be reported.
    for benchmark in benchmark_list:
        if benchmark[0] not in result:
            result[benchmark[0]] = []
        total_benchmark += max(0, target_num - len(result[benchmark[0]]))
    # Second pass: launch the missing repetitions.
    for benchmark in benchmark_list:
        pre_size = max(0, target_num - len(result[benchmark[0]]))
        for i in range(pre_size):
            run_benchmark(benchmark[0], benchmark[1], time_limit)

    collect_all_result()
    # Clean up any solver process that is still alive.
    os.system("killall main")

    if args.experiment == 0:
        for i in range(1, 5):
            draw(i)
    else:
        draw(args.experiment)
